import os
from pathlib import Path
from typing import Optional, Tuple, Union

import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio._internal import download_url_to_file
from torchaudio.datasets.utils import _extract_tar

# The following lists prefixed with `filtered_` provide a filtered split
# that:
#
# a. Mitigate a known issue with GTZAN (duplication)
#
# b. Provide a standard split for testing it against other
#    methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
#
# Those are used when GTZAN is initialised with the `filtered` keyword.
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.

gtzan_genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]

filtered_test = [
    "blues.00012",
    "blues.00013",
    "blues.00014",
    "blues.00015",
    "blues.00016",
    "blues.00017",
    "blues.00018",
    "blues.00019",
    "blues.00020",
    "blues.00021",
    "blues.00022",
    "blues.00023",
    "blues.00024",
    "blues.00025",
    "blues.00026",
    "blues.00027",
    "blues.00028",
    "blues.00061",
    "blues.00062",
    "blues.00063",
    "blues.00064",
    "blues.00065",
    "blues.00066",
    "blues.00067",
    "blues.00068",
    "blues.00069",
    "blues.00070",
    "blues.00071",
    "blues.00072",
    "blues.00098",
    "blues.00099",
    "classical.00011",
    "classical.00012",
    "classical.00013",
    "classical.00014",
    "classical.00015",
    "classical.00016",
    "classical.00017",
    "classical.00018",
    "classical.00019",
    "classical.00020",
    "classical.00021",
    "classical.00022",
    "classical.00023",
    "classical.00024",
    "classical.00025",
    "classical.00026",
    "classical.00027",
    "classical.00028",
    "classical.00029",
    "classical.00034",
    "classical.00035",
    "classical.00036",
    "classical.00037",
    "classical.00038",
    "classical.00039",
    "classical.00040",
    "classical.00041",
    "classical.00049",
    "classical.00077",
    "classical.00078",
    "classical.00079",
    "country.00030",
    "country.00031",
    "country.00032",
    "country.00033",
    "country.00034",
    "country.00035",
    "country.00036",
    "country.00037",
    "country.00038",
    "country.00039",
    "country.00040",
    "country.00043",
    "country.00044",
    "country.00046",
    "country.00047",
    "country.00048",
    "country.00050",
    "country.00051",
    "country.00053",
    "country.00054",
    "country.00055",
    "country.00056",
    "country.00057",
    "country.00058",
    "country.00059",
    "country.00060",
    "country.00061",
    "country.00062",
    "country.00063",
    "country.00064",
    "disco.00001",
    "disco.00021",
    "disco.00058",
    "disco.00062",
    "disco.00063",
    "disco.00064",
    "disco.00065",
    "disco.00066",
    "disco.00069",
    "disco.00076",
    "disco.00077",
    "disco.00078",
    "disco.00079",
    "disco.00080",
    "disco.00081",
    "disco.00082",
    "disco.00083",
    "disco.00084",
    "disco.00085",
    "disco.00086",
    "disco.00087",
    "disco.00088",
    "disco.00091",
    "disco.00092",
    "disco.00093",
    "disco.00094",
    "disco.00096",
    "disco.00097",
    "disco.00099",
    "hiphop.00000",
    "hiphop.00026",
    "hiphop.00027",
    "hiphop.00030",
    "hiphop.00040",
    "hiphop.00043",
    "hiphop.00044",
    "hiphop.00045",
    "hiphop.00051",
    "hiphop.00052",
    "hiphop.00053",
    "hiphop.00054",
    "hiphop.00062",
    "hiphop.00063",
    "hiphop.00064",
    "hiphop.00065",
    "hiphop.00066",
    "hiphop.00067",
    "hiphop.00068",
    "hiphop.00069",
    "hiphop.00070",
    "hiphop.00071",
    "hiphop.00072",
    "hiphop.00073",
    "hiphop.00074",
    "hiphop.00075",
    "hiphop.00099",
    "jazz.00073",
    "jazz.00074",
    "jazz.00075",
    "jazz.00076",
    "jazz.00077",
    "jazz.00078",
    "jazz.00079",
    "jazz.00080",
    "jazz.00081",
    "jazz.00082",
    "jazz.00083",
    "jazz.00084",
    "jazz.00085",
    "jazz.00086",
    "jazz.00087",
    "jazz.00088",
    "jazz.00089",
    "jazz.00090",
    "jazz.00091",
    "jazz.00092",
    "jazz.00093",
    "jazz.00094",
    "jazz.00095",
    "jazz.00096",
    "jazz.00097",
    "jazz.00098",
    "jazz.00099",
    "metal.00012",
    "metal.00013",
    "metal.00014",
    "metal.00015",
    "metal.00022",
    "metal.00023",
    "metal.00025",
    "metal.00026",
    "metal.00027",
    "metal.00028",
    "metal.00029",
    "metal.00030",
    "metal.00031",
    "metal.00032",
    "metal.00033",
    "metal.00038",
    "metal.00039",
    "metal.00067",
    "metal.00070",
    "metal.00073",
    "metal.00074",
    "metal.00075",
    "metal.00078",
    "metal.00083",
    "metal.00085",
    "metal.00087",
    "metal.00088",
    "pop.00000",
    "pop.00001",
    "pop.00013",
    "pop.00014",
    "pop.00043",
    "pop.00063",
    "pop.00064",
    "pop.00065",
    "pop.00066",
    "pop.00069",
    "pop.00070",
    "pop.00071",
    "pop.00072",
    "pop.00073",
    "pop.00074",
    "pop.00075",
    "pop.00076",
    "pop.00077",
    "pop.00078",
    "pop.00079",
    "pop.00082",
    "pop.00088",
    "pop.00089",
    "pop.00090",
    "pop.00091",
    "pop.00092",
    "pop.00093",
    "pop.00094",
    "pop.00095",
    "pop.00096",
    "reggae.00034",
    "reggae.00035",
    "reggae.00036",
    "reggae.00037",
    "reggae.00038",
    "reggae.00039",
    "reggae.00040",
    "reggae.00046",
    "reggae.00047",
    "reggae.00048",
    "reggae.00052",
    "reggae.00053",
    "reggae.00064",
    "reggae.00065",
    "reggae.00066",
    "reggae.00067",
    "reggae.00068",
    "reggae.00071",
    "reggae.00079",
    "reggae.00082",
    "reggae.00083",
    "reggae.00084",
    "reggae.00087",
    "reggae.00088",
    "reggae.00089",
    "reggae.00090",
    "rock.00010",
    "rock.00011",
    "rock.00012",
    "rock.00013",
    "rock.00014",
    "rock.00015",
    "rock.00027",
    "rock.00028",
    "rock.00029",
    "rock.00030",
    "rock.00031",
    "rock.00032",
    "rock.00033",
    "rock.00034",
    "rock.00035",
    "rock.00036",
    "rock.00037",
    "rock.00039",
    "rock.00040",
    "rock.00041",
    "rock.00042",
    "rock.00043",
    "rock.00044",
    "rock.00045",
    "rock.00046",
    "rock.00047",
    "rock.00048",
    "rock.00086",
    "rock.00087",
    "rock.00088",
    "rock.00089",
    "rock.00090",
]

filtered_train = [
    "blues.00029",
    "blues.00030",
    "blues.00031",
    "blues.00032",
    "blues.00033",
    "blues.00034",
    "blues.00035",
    "blues.00036",
    "blues.00037",
    "blues.00038",
    "blues.00039",
    "blues.00040",
    "blues.00041",
    "blues.00042",
    "blues.00043",
    "blues.00044",
    "blues.00045",
    "blues.00046",
    "blues.00047",
    "blues.00048",
    "blues.00049",
    "blues.00073",
    "blues.00074",
    "blues.00075",
    "blues.00076",
    "blues.00077",
    "blues.00078",
    "blues.00079",
    "blues.00080",
    "blues.00081",
    "blues.00082",
    "blues.00083",
    "blues.00084",
    "blues.00085",
    "blues.00086",
    "blues.00087",
    "blues.00088",
    "blues.00089",
    "blues.00090",
    "blues.00091",
    "blues.00092",
    "blues.00093",
    "blues.00094",
    "blues.00095",
    "blues.00096",
    "blues.00097",
    "classical.00030",
    "classical.00031",
    "classical.00032",
    "classical.00033",
    "classical.00043",
    "classical.00044",
    "classical.00045",
    "classical.00046",
    "classical.00047",
    "classical.00048",
    "classical.00050",
    "classical.00051",
    "classical.00052",
    "classical.00053",
    "classical.00054",
    "classical.00055",
    "classical.00056",
    "classical.00057",
    "classical.00058",
    "classical.00059",
    "classical.00060",
    "classical.00061",
    "classical.00062",
    "classical.00063",
    "classical.00064",
    "classical.00065",
    "classical.00066",
    "classical.00067",
    "classical.00080",
    "classical.00081",
    "classical.00082",
    "classical.00083",
    "classical.00084",
    "classical.00085",
    "classical.00086",
    "classical.00087",
    "classical.00088",
    "classical.00089",
    "classical.00090",
    "classical.00091",
    "classical.00092",
    "classical.00093",
    "classical.00094",
    "classical.00095",
    "classical.00096",
    "classical.00097",
    "classical.00098",
    "classical.00099",
    "country.00019",
    "country.00020",
    "country.00021",
    "country.00022",
    "country.00023",
    "country.00024",
    "country.00025",
    "country.00026",
    "country.00028",
    "country.00029",
    "country.00065",
    "country.00066",
    "country.00067",
    "country.00068",
    "country.00069",
    "country.00070",
    "country.00071",
    "country.00072",
    "country.00073",
    "country.00074",
    "country.00075",
    "country.00076",
    "country.00077",
    "country.00078",
    "country.00079",
    "country.00080",
    "country.00081",
    "country.00082",
    "country.00083",
    "country.00084",
    "country.00085",
    "country.00086",
    "country.00087",
    "country.00088",
    "country.00089",
    "country.00090",
    "country.00091",
    "country.00092",
    "country.00093",
    "country.00094",
    "country.00095",
    "country.00096",
    "country.00097",
    "country.00098",
    "country.00099",
    "disco.00005",
    "disco.00015",
    "disco.00016",
    "disco.00017",
    "disco.00018",
    "disco.00019",
    "disco.00020",
    "disco.00022",
    "disco.00023",
    "disco.00024",
    "disco.00025",
    "disco.00026",
    "disco.00027",
    "disco.00028",
    "disco.00029",
    "disco.00030",
    "disco.00031",
    "disco.00032",
    "disco.00033",
    "disco.00034",
    "disco.00035",
    "disco.00036",
    "disco.00037",
    "disco.00039",
    "disco.00040",
    "disco.00041",
    "disco.00042",
    "disco.00043",
    "disco.00044",
    "disco.00045",
    "disco.00047",
    "disco.00049",
    "disco.00053",
    "disco.00054",
    "disco.00056",
    "disco.00057",
    "disco.00059",
    "disco.00061",
    "disco.00070",
    "disco.00073",
    "disco.00074",
    "disco.00089",
    "hiphop.00002",
    "hiphop.00003",
    "hiphop.00004",
    "hiphop.00005",
    "hiphop.00006",
    "hiphop.00007",
    "hiphop.00008",
    "hiphop.00009",
    "hiphop.00010",
    "hiphop.00011",
    "hiphop.00012",
    "hiphop.00013",
    "hiphop.00014",
    "hiphop.00015",
    "hiphop.00016",
    "hiphop.00017",
    "hiphop.00018",
    "hiphop.00019",
    "hiphop.00020",
    "hiphop.00021",
    "hiphop.00022",
    "hiphop.00023",
    "hiphop.00024",
    "hiphop.00025",
    "hiphop.00028",
    "hiphop.00029",
    "hiphop.00031",
    "hiphop.00032",
    "hiphop.00033",
    "hiphop.00034",
    "hiphop.00035",
    "hiphop.00036",
    "hiphop.00037",
    "hiphop.00038",
    "hiphop.00041",
    "hiphop.00042",
    "hiphop.00055",
    "hiphop.00056",
    "hiphop.00057",
    "hiphop.00058",
    "hiphop.00059",
    "hiphop.00060",
    "hiphop.00061",
    "hiphop.00077",
    "hiphop.00078",
    "hiphop.00079",
    "hiphop.00080",
    "jazz.00000",
    "jazz.00001",
    "jazz.00011",
    "jazz.00012",
    "jazz.00013",
    "jazz.00014",
    "jazz.00015",
    "jazz.00016",
    "jazz.00017",
    "jazz.00018",
    "jazz.00019",
    "jazz.00020",
    "jazz.00021",
    "jazz.00022",
    "jazz.00023",
    "jazz.00024",
    "jazz.00041",
    "jazz.00047",
    "jazz.00048",
    "jazz.00049",
    "jazz.00050",
    "jazz.00051",
    "jazz.00052",
    "jazz.00053",
    "jazz.00054",
    "jazz.00055",
    "jazz.00056",
    "jazz.00057",
    "jazz.00058",
    "jazz.00059",
    "jazz.00060",
    "jazz.00061",
    "jazz.00062",
    "jazz.00063",
    "jazz.00064",
    "jazz.00065",
    "jazz.00066",
    "jazz.00067",
    "jazz.00068",
    "jazz.00069",
    "jazz.00070",
    "jazz.00071",
    "jazz.00072",
    "metal.00002",
    "metal.00003",
    "metal.00005",
    "metal.00021",
    "metal.00024",
    "metal.00035",
    "metal.00046",
    "metal.00047",
    "metal.00048",
    "metal.00049",
    "metal.00050",
    "metal.00051",
    "metal.00052",
    "metal.00053",
    "metal.00054",
    "metal.00055",
    "metal.00056",
    "metal.00057",
    "metal.00059",
    "metal.00060",
    "metal.00061",
    "metal.00062",
    "metal.00063",
    "metal.00064",
    "metal.00065",
    "metal.00066",
    "metal.00069",
    "metal.00071",
    "metal.00072",
    "metal.00079",
    "metal.00080",
    "metal.00084",
    "metal.00086",
    "metal.00089",
    "metal.00090",
    "metal.00091",
    "metal.00092",
    "metal.00093",
    "metal.00094",
    "metal.00095",
    "metal.00096",
    "metal.00097",
    "metal.00098",
    "metal.00099",
    "pop.00002",
    "pop.00003",
    "pop.00004",
    "pop.00005",
    "pop.00006",
    "pop.00007",
    "pop.00008",
    "pop.00009",
    "pop.00011",
    "pop.00012",
    "pop.00016",
    "pop.00017",
    "pop.00018",
    "pop.00019",
    "pop.00020",
    "pop.00023",
    "pop.00024",
    "pop.00025",
    "pop.00026",
    "pop.00027",
    "pop.00028",
    "pop.00029",
    "pop.00031",
    "pop.00032",
    "pop.00033",
    "pop.00034",
    "pop.00035",
    "pop.00036",
    "pop.00038",
    "pop.00039",
    "pop.00040",
    "pop.00041",
    "pop.00042",
    "pop.00044",
    "pop.00046",
    "pop.00049",
    "pop.00050",
    "pop.00080",
    "pop.00097",
    "pop.00098",
    "pop.00099",
    "reggae.00000",
    "reggae.00001",
    "reggae.00002",
    "reggae.00004",
    "reggae.00006",
    "reggae.00009",
    "reggae.00011",
    "reggae.00012",
    "reggae.00014",
    "reggae.00015",
    "reggae.00016",
    "reggae.00017",
    "reggae.00018",
    "reggae.00019",
    "reggae.00020",
    "reggae.00021",
    "reggae.00022",
    "reggae.00023",
    "reggae.00024",
    "reggae.00025",
    "reggae.00026",
    "reggae.00027",
    "reggae.00028",
    "reggae.00029",
    "reggae.00030",
    "reggae.00031",
    "reggae.00032",
    "reggae.00042",
    "reggae.00043",
    "reggae.00044",
    "reggae.00045",
    "reggae.00049",
    "reggae.00050",
    "reggae.00051",
    "reggae.00054",
    "reggae.00055",
    "reggae.00056",
    "reggae.00057",
    "reggae.00058",
    "reggae.00059",
    "reggae.00060",
    "reggae.00063",
    "reggae.00069",
    "rock.00000",
    "rock.00001",
    "rock.00002",
    "rock.00003",
    "rock.00004",
    "rock.00005",
    "rock.00006",
    "rock.00007",
    "rock.00008",
    "rock.00009",
    "rock.00016",
    "rock.00017",
    "rock.00018",
    "rock.00019",
    "rock.00020",
    "rock.00021",
    "rock.00022",
    "rock.00023",
    "rock.00024",
    "rock.00025",
    "rock.00026",
    "rock.00057",
    "rock.00058",
    "rock.00059",
    "rock.00060",
    "rock.00061",
    "rock.00062",
    "rock.00063",
    "rock.00064",
    "rock.00065",
    "rock.00066",
    "rock.00067",
    "rock.00068",
    "rock.00069",
    "rock.00070",
    "rock.00091",
    "rock.00092",
    "rock.00093",
    "rock.00094",
    "rock.00095",
    "rock.00096",
    "rock.00097",
    "rock.00098",
    "rock.00099",
]

filtered_valid = [
    "blues.00000",
    "blues.00001",
    "blues.00002",
    "blues.00003",
    "blues.00004",
    "blues.00005",
    "blues.00006",
    "blues.00007",
    "blues.00008",
    "blues.00009",
    "blues.00010",
    "blues.00011",
    "blues.00050",
    "blues.00051",
    "blues.00052",
    "blues.00053",
    "blues.00054",
    "blues.00055",
    "blues.00056",
    "blues.00057",
    "blues.00058",
    "blues.00059",
    "blues.00060",
    "classical.00000",
    "classical.00001",
    "classical.00002",
    "classical.00003",
    "classical.00004",
    "classical.00005",
    "classical.00006",
    "classical.00007",
    "classical.00008",
    "classical.00009",
    "classical.00010",
    "classical.00068",
    "classical.00069",
    "classical.00070",
    "classical.00071",
    "classical.00072",
    "classical.00073",
    "classical.00074",
    "classical.00075",
    "classical.00076",
    "country.00000",
    "country.00001",
    "country.00002",
    "country.00003",
    "country.00004",
    "country.00005",
    "country.00006",
    "country.00007",
    "country.00009",
    "country.00010",
    "country.00011",
    "country.00012",
    "country.00013",
    "country.00014",
    "country.00015",
    "country.00016",
    "country.00017",
    "country.00018",
    "country.00027",
    "country.00041",
    "country.00042",
    "country.00045",
    "country.00049",
    "disco.00000",
    "disco.00002",
    "disco.00003",
    "disco.00004",
    "disco.00006",
    "disco.00007",
    "disco.00008",
    "disco.00009",
    "disco.00010",
    "disco.00011",
    "disco.00012",
    "disco.00013",
    "disco.00014",
    "disco.00046",
    "disco.00048",
    "disco.00052",
    "disco.00067",
    "disco.00068",
    "disco.00072",
    "disco.00075",
    "disco.00090",
    "disco.00095",
    "hiphop.00081",
    "hiphop.00082",
    "hiphop.00083",
    "hiphop.00084",
    "hiphop.00085",
    "hiphop.00086",
    "hiphop.00087",
    "hiphop.00088",
    "hiphop.00089",
    "hiphop.00090",
    "hiphop.00091",
    "hiphop.00092",
    "hiphop.00093",
    "hiphop.00094",
    "hiphop.00095",
    "hiphop.00096",
    "hiphop.00097",
    "hiphop.00098",
    "jazz.00002",
    "jazz.00003",
    "jazz.00004",
    "jazz.00005",
    "jazz.00006",
    "jazz.00007",
    "jazz.00008",
    "jazz.00009",
    "jazz.00010",
    "jazz.00025",
    "jazz.00026",
    "jazz.00027",
    "jazz.00028",
    "jazz.00029",
    "jazz.00030",
    "jazz.00031",
    "jazz.00032",
    "metal.00000",
    "metal.00001",
    "metal.00006",
    "metal.00007",
    "metal.00008",
    "metal.00009",
    "metal.00010",
    "metal.00011",
    "metal.00016",
    "metal.00017",
    "metal.00018",
    "metal.00019",
    "metal.00020",
    "metal.00036",
    "metal.00037",
    "metal.00068",
    "metal.00076",
    "metal.00077",
    "metal.00081",
    "metal.00082",
    "pop.00010",
    "pop.00053",
    "pop.00055",
    "pop.00058",
    "pop.00059",
    "pop.00060",
    "pop.00061",
    "pop.00062",
    "pop.00081",
    "pop.00083",
    "pop.00084",
    "pop.00085",
    "pop.00086",
    "reggae.00061",
    "reggae.00062",
    "reggae.00070",
    "reggae.00072",
    "reggae.00074",
    "reggae.00076",
    "reggae.00077",
    "reggae.00078",
    "reggae.00085",
    "reggae.00092",
    "reggae.00093",
    "reggae.00094",
    "reggae.00095",
    "reggae.00096",
    "reggae.00097",
    "reggae.00098",
    "reggae.00099",
    "rock.00038",
    "rock.00049",
    "rock.00050",
    "rock.00051",
    "rock.00052",
    "rock.00053",
    "rock.00054",
    "rock.00055",
    "rock.00056",
    "rock.00071",
    "rock.00072",
    "rock.00073",
    "rock.00074",
    "rock.00075",
    "rock.00076",
    "rock.00077",
    "rock.00078",
    "rock.00079",
    "rock.00080",
    "rock.00081",
    "rock.00082",
    "rock.00083",
    "rock.00084",
    "rock.00085",
]


URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
FOLDER_IN_ARCHIVE = "genres"
_CHECKSUMS = {
    "http://opihi.cs.uvic.ca/sound/genres.tar.gz": "24347e0223d2ba798e0a558c4c172d9d4a19c00bb7963fe055d183dadb4ef2c6"
}


def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, str]:
    """
    Loads a file from the dataset and returns the raw waveform
    as a Torch Tensor, its sample rate as an integer, and its
    genre as a string.
    """
    # Filenames are of the form label.id, e.g. blues.00078
    label, _ = fileid.split(".")

    # Read wav
    file_audio = os.path.join(path, label, fileid + ext_audio)
    waveform, sample_rate = torchaudio.load(file_audio)

    return waveform, sample_rate, label


class GTZAN(Dataset):
    """*GTZAN* :cite:`tzanetakis_essl_cook_2001` dataset.

    Note:
        Please see http://marsyas.info/downloads/datasets.html if you are planning to use
        this dataset to publish results.

    Note:
        As of October 2022, the download link is not currently working. Setting ``download=True``
        in GTZAN dataset will result in a URL connection error.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
        folder_in_archive (str, optional): The top-level directory of the dataset.
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        subset (str or None, optional): Which subset of the dataset to use.
            One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
            If ``None``, the entire dataset is used. (default: ``None``).
    """

    _ext_audio = ".wav"

    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Optional[str] = None,
    ) -> None:

        # super(GTZAN, self).__init__()

        # Get string representation of 'root' in case Path object is passed
        root = os.fspath(root)

        self.root = root
        self.url = url
        self.folder_in_archive = folder_in_archive
        self.download = download
        self.subset = subset

        if subset is not None and subset not in ["training", "validation", "testing"]:
            raise ValueError("When `subset` is not None, it must be one of ['training', 'validation', 'testing'].")

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url_to_file(url, archive, hash_prefix=checksum)
                _extract_tar(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        if self.subset is None:
            # Check every subdirectory under dataset root
            # which has the same name as the genres in
            # GTZAN (e.g. `root_dir'/blues/, `root_dir'/rock, etc.)
            # This lets users remove or move around song files,
            # useful when e.g. they want to use only some of the files
            # in a genre or want to label other files with a different
            # genre.
            self._walker = []

            root = os.path.expanduser(self._path)

            for directory in gtzan_genres:
                fulldir = os.path.join(root, directory)

                if not os.path.exists(fulldir):
                    continue

                songs_in_genre = os.listdir(fulldir)
                songs_in_genre.sort()
                for fname in songs_in_genre:
                    name, ext = os.path.splitext(fname)
                    if ext.lower() == ".wav" and "." in name:
                        # Check whether the file is of the form
                        # `gtzan_genre`.`5 digit number`.wav
                        genre, num = name.split(".")
                        if genre in gtzan_genres and len(num) == 5 and num.isdigit():
                            self._walker.append(name)
        else:
            if self.subset == "training":
                self._walker = filtered_train
            elif self.subset == "validation":
                self._walker = filtered_valid
            elif self.subset == "testing":
                self._walker = filtered_test

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                Label
        """
        fileid = self._walker[n]
        item = load_gtzan_item(fileid, self._path, self._ext_audio)
        waveform, sample_rate, label = item
        return waveform, sample_rate, label

    def __len__(self) -> int:
        return len(self._walker)
